package com.example.demo.xss;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Request wrapper that filters HTTP request parameters to mitigate XSS (and, crudely, SQL
 * injection) attacks.
 *
 * <p>No HTML parser (e.g. Jsoup) is involved: values are passed through a fixed sequence of
 * literal/regex replacements followed by a keyword blacklist. Blacklist filtering of this kind is
 * lossy (for example "for you" becomes "fINVALID you") and is not a substitute for context-aware
 * output encoding or parameterized SQL.
 *
 * <p>{@code getParameter}, {@code getParameterValues} and {@code getParameterMap} are filtered
 * (the latter also filters parameter names). {@code getParameterNames} is not overridden and
 * still reports the raw names. Use {@link #getOrgRequest(HttpServletRequest)} to reach the
 * unfiltered request.
 *
 * @author MrBird
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
	/** Pipe-separated keywords/characters that are replaced when they appear next to a space. */
	private static final String KEY_WORDS =
			"and|exec|insert|select|delete|update|count|*|%|chr|mid|master|truncate|char|declare|;|or|-|+";
	/** Blacklist parsed from {@link #KEY_WORDS}; unmodifiable after class initialisation. */
	private static final Set<String> NOT_ALLOWED_KEY_WORDS;
	/** Text substituted for every occurrence of a triggered blacklist keyword. */
	private static final String REPLACED_STRING = "INVALID";

	static {
		Set<String> words = new HashSet<>(0);
		Collections.addAll(words, KEY_WORDS.split("\\|"));
		NOT_ALLOWED_KEY_WORDS = Collections.unmodifiableSet(words);
	}

	private final HttpServletRequest orgRequest;
	/** Accepted from the caller but not currently consulted by any filtering logic. */
	private final boolean isIncludeRichText;

	XssHttpServletRequestWrapper(HttpServletRequest request, boolean isIncludeRichText) {
		super(request);
		this.orgRequest = request;
		this.isIncludeRichText = isIncludeRichText;
	}

	/**
	 * Returns the named parameter with XSS/SQL-keyword filtering applied to its value.
	 * The parameter name itself is used as given. To read the unfiltered value, use
	 * {@link #getOrgRequest(HttpServletRequest)}.
	 *
	 * @param parameter the parameter name
	 * @return the filtered value, or {@code null} if the parameter is absent
	 */
	@Override
	public String getParameter(String parameter) {
		String value = super.getParameter(parameter);
		return value == null ? null : cleanXSS(value);
	}

	/**
	 * Returns all values of the named parameter, each filtered.
	 *
	 * @param parameter the parameter name
	 * @return a new array of filtered values, or {@code null} if the parameter is absent
	 */
	@Override
	public String[] getParameterValues(String parameter) {
		return cleanAll(super.getParameterValues(parameter));
	}

	/**
	 * Returns the original (unwrapped) request.
	 */
	private HttpServletRequest getOrgRequest() {
		return orgRequest;
	}

	/**
	 * Returns the original request if {@code req} is an instance of this wrapper, otherwise
	 * {@code req} itself.
	 *
	 * @param req a possibly wrapped request
	 * @return the unfiltered request
	 */
	public static HttpServletRequest getOrgRequest(HttpServletRequest req) {
		if (req instanceof XssHttpServletRequestWrapper) {
			return ((XssHttpServletRequestWrapper) req).getOrgRequest();
		}
		return req;
	}

	/**
	 * Returns a new map in which both parameter names and values have been filtered.
	 * If two raw names filter to the same string, the later entry (in the wrapped map's
	 * iteration order) wins.
	 *
	 * @return the filtered map, or {@code null} if the wrapped request returned {@code null}
	 */
	@Override
	public Map<String, String[]> getParameterMap() {
		Map<String, String[]> values = super.getParameterMap();
		if (values == null) {
			return null;
		}
		Map<String, String[]> result = new HashMap<>();
		for (Map.Entry<String, String[]> entry : values.entrySet()) {
			result.put(cleanXSS(entry.getKey()), cleanAll(entry.getValue()));
		}
		return result;
	}

	/**
	 * Filters every element of {@code values} into a new array. Null elements stay null.
	 *
	 * @return the filtered copy, or {@code null} when {@code values} is {@code null}
	 */
	private String[] cleanAll(String[] values) {
		if (values == null) {
			return null;
		}
		String[] cleaned = new String[values.length];
		for (int i = 0; i < values.length; i++) {
			cleaned[i] = values[i] == null ? null : cleanXSS(values[i]);
		}
		return cleaned;
	}

	/**
	 * Applies the XSS replacements and then the SQL keyword blacklist to a non-null value.
	 */
	private String cleanXSS(String valueP) {
		// HTML-encode markup and quoting characters. The entities must not contain spaces:
		// "& #40;" is not an entity and would be rendered to users literally.
		String value = valueP.replace("<", "&lt;").replace(">", "&gt;");
		value = value.replace("(", "&#40;").replace(")", "&#41;");
		value = value.replace("'", "&#39;");
		// NOTE(review): these two rules run after the encoding above, so "eval(" can never match
		// (parentheses are already encoded), and only double-quoted javascript: strings are caught
		// (single quotes are already encoded). Kept as-is to avoid widening what gets stripped.
		value = value.replaceAll("eval\\((.*)\\)", "");
		value = value.replaceAll("[\\\"\\'][\\s]*javascript:(.*)[\\\"\\']", "\"\"");
		// Case-sensitive: only the lower-case substring "script" is removed.
		value = value.replace("script", "");
		return cleanSqlKeyWords(value);
	}

	/**
	 * Replaces, case-insensitively, every occurrence of each blacklisted keyword with
	 * {@link #REPLACED_STRING}. A keyword is replaced only if it appears adjacent to a space
	 * somewhere in the value and the value is more than four characters longer than the keyword.
	 * Once triggered, occurrences inside words are replaced too.
	 */
	private String cleanSqlKeyWords(String value) {
		String result = value;
		// Locale.ROOT so that e.g. "INSERT" is still detected under a Turkish default locale.
		String lowered = value.toLowerCase(Locale.ROOT);
		for (String keyword : NOT_ALLOWED_KEY_WORDS) {
			// " " + keyword + " " is implied by either check, so it is not tested separately.
			if (lowered.length() > keyword.length() + 4
					&& (lowered.contains(" " + keyword) || lowered.contains(keyword + " "))) {
				result = ignoreCaseReplace(result, keyword, REPLACED_STRING);
			}
		}
		return result;
	}

	/**
	 * Replaces every case-insensitive occurrence of the literal {@code oldstring} in
	 * {@code source} with the literal {@code newstring}. Both are quoted, so keywords such as
	 * "*" and "+" no longer cause a PatternSyntaxException.
	 */
	private String ignoreCaseReplace(String source, String oldstring,
	                                 String newstring) {
		Pattern p = Pattern.compile(Pattern.quote(oldstring), Pattern.CASE_INSENSITIVE);
		return p.matcher(source).replaceAll(Matcher.quoteReplacement(newstring));
	}
}
